import os
import sys
import gzip
import pysam
from Bio import SeqIO


# Name of the sequencing library to process, given as the first command-line
# argument; it is used to construct the input and output file names.
library = sys.argv[1]

# Name of the genome assembly whose chromosome sizes are read.
assembly = 'hg38'

def read_chromosome_sizes(assembly,
                          directory="/osc-fs_home/scratch/mdehoon/Data/Genomes"):
    """Read the chromosome names and sizes of a genome assembly.

    The data are read from "<directory>/<assembly>/<assembly>.chrom.sizes",
    a text file with a chromosome name and its size, separated by whitespace,
    on each line. Chromosomes whose name ends in "_alt" are skipped, as are
    blank lines.

    Arguments:
     - assembly: name of the genome assembly, for example "hg38".
     - directory: directory containing one subdirectory per assembly.

    Returns a tuple (chromosomes, sizes) of two lists of equal length, in the
    order in which the chromosomes appear in the file; sizes are integers.

    Raises ValueError if a non-blank line does not consist of exactly two
    fields, or if the size is not an integer.
    """
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    chromosomes = []
    sizes = []
    # The with-block closes the file even if a line fails to parse.
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            chromosome, size = line.split()
            if chromosome.endswith("_alt"):
                continue
            chromosomes.append(chromosome)
            sizes.append(int(size))
    return chromosomes, sizes

def _create_unmapped_segment(query_name, is_read1):
    """Create an unmapped, paired pysam.AlignedSegment for one mate."""
    segment = pysam.AlignedSegment()
    segment.query_name = query_name
    segment.is_unmapped = True
    segment.is_read1 = is_read1
    segment.is_read2 = not is_read1
    segment.is_paired = True
    return segment


def generate_unmapped_alignments(library):
    """Generate unmapped alignments for the reads listed in the index file.

    The index file "<library>.index.txt" is read from the current directory.
    Each line consists of a read name followed by whitespace and a term
    "READ1_<integer>,READ2_<integer>". For each line, two unmapped, paired
    alignments carrying the read name are yielded: first the one flagged as
    read 1, then the one flagged as read 2.

    Raises AssertionError if the READ1_/READ2_ prefixes are missing, and
    ValueError if a line is malformed or the part following the prefix is
    not an integer.
    """
    filename = "%s.index.txt" % library
    # The with-block closes the file also when the generator is closed before
    # it is exhausted, or when one of the checks below fails.
    with open(filename) as stream:
        for line in stream:
            query_name, read1_read2 = line.split()
            read1, read2 = read1_read2.split(",")
            assert read1.startswith("READ1_")
            assert read2.startswith("READ2_")
            # The numbers are not used; the conversion only verifies that
            # the text following the prefix is an integer.
            int(read1[6:])
            int(read2[6:])
            yield _create_unmapped_segment(query_name, True)
            yield _create_unmapped_segment(query_name, False)

def get_genome_location(alignment):
    """Return the genomic location covered by a pair of alignments.

    The argument is a tuple (alignment1, alignment2), where alignment1 must
    be flagged as read 1 and alignment2 as read 2. Both must refer to the
    same reference sequence, and exactly one of them must be on the reverse
    strand.

    Returns a tuple (chromosome, start, end, strand, cigar). The strand is
    "+" if read 1 is on the forward strand and "-" otherwise; start is the
    reference start of the forward-strand mate, end is the reference end of
    the reverse-strand mate, and cigar is the tuple of the cigar attributes
    of read 1 and read 2.

    Raises AssertionError if any of the requirements above is violated, or
    if start is not less than end (the two alignments are printed first in
    that case).
    """
    mate1, mate2 = alignment
    chromosome = mate1.reference_name
    assert chromosome == mate2.reference_name
    assert mate1.is_read1
    assert not mate1.is_read2
    assert not mate2.is_read1
    assert mate2.is_read2
    if not mate1.is_reverse:
        assert mate2.is_reverse
        strand = "+"
        forward, reverse = mate1, mate2
    else:
        assert not mate2.is_reverse
        strand = "-"
        forward, reverse = mate2, mate1
    start = forward.reference_start
    end = reverse.reference_end
    if end <= start:
        # Show the offending pair before the assertion below fails.
        print(mate1.is_reverse, mate2.is_reverse)
        print(mate1)
        print(mate2)
    assert start < end
    return (chromosome, start, end, strand, (mate1.cigar, mate2.cigar))

def get_transcripts(alignments):
    """Return the XR tag values of a list of alignment pairs as one string.

    The argument is a sequence of tuples (alignment1, alignment2). The XR
    tag is taken from the first alignment of each pair; duplicate values are
    removed, and the remaining values are sorted and joined by commas.
    """
    names = {mate1.get_tag("XR") for mate1, mate2 in alignments}
    return ",".join(sorted(names))

# Targets for which the XR tag of a written pair is replaced by the combined
# XR tag values of all alignments it represents (see get_transcripts).
_TRANSCRIPT_TARGETS = ("snRNA", "scRNA", "snoRNA", "scaRNA",
                       "mRNA", "lncRNA", "gencode", "fantomcat")


def _select_shortest_alignments(alignments):
    """Return the mapped pairs spanning the shortest distance on the reference.

    Unmapped pairs are returned only if none of the pairs is mapped. Pairs
    spanning a longer distance than the shortest one are dropped regardless
    of their position in the list.
    """
    shortest_length = None
    selected = []
    for alignment in alignments:
        alignment1, alignment2 = alignment
        if alignment1.is_unmapped:
            assert alignment2.is_unmapped
            if shortest_length is None:
                selected.append(alignment)
            continue
        assert not alignment2.is_unmapped
        if alignment1.is_reverse:
            assert not alignment2.is_reverse
            length = alignment1.reference_end - alignment2.reference_start
        else:
            assert alignment2.is_reverse
            length = alignment2.reference_end - alignment1.reference_start
        if shortest_length is None or length < shortest_length:
            # A shorter pair supersedes everything selected so far,
            # including any unmapped pairs.
            shortest_length = length
            selected = [alignment]
        elif length == shortest_length:
            selected.append(alignment)
    return selected


def _check_equal_lengths(alignments):
    """Verify that the XL tag is present and the same for all pairs."""
    expected = None
    for alignment1, alignment2 in alignments:
        length = alignment1.get_tag("XL")
        assert length is not None
        if expected is None:
            expected = length
        else:
            assert expected == length


def _write_pair(output, alignment, alignments, target, unmapped=False):
    """Write one pair as the representative of a list of alignment pairs."""
    alignment1, alignment2 = alignment
    if target in _TRANSCRIPT_TARGETS:
        alignment1.set_tag("XR", get_transcripts(alignments))
    if unmapped:
        alignment1.is_unmapped = True
        alignment2.is_unmapped = True
    assert alignment1.get_tag("XT") == target
    output.write(alignment1)
    output.write(alignment2)


def write_alignments(output, alignments, chromosomes, target):
    """Write the alignments of one read pair to the output file.

    Arguments:
     - output: object with a write method accepting an alignment, such as a
       pysam.AlignmentFile opened for writing.
     - alignments: list of tuples (alignment1, alignment2), all belonging to
       the same read pair.
     - chromosomes: names of the reference sequences of the output file.
     - target: name of the target the alignments were generated for.

    For the target "unmapped", the single pair is written as is. For the
    target "genome", only the pairs spanning the shortest distance on the
    reference are considered; for all other targets, the XL tag must be the
    same for all pairs. If at least one pair is aligned to a sequence listed
    in chromosomes, the pairs on those sequences are sorted by genomic
    location and one pair is written for each distinct location. Otherwise,
    the last pair is flagged as unmapped and written.

    Raises ValueError if alignments is empty for a target other than
    "unmapped", and AssertionError if a consistency check fails.
    """
    from itertools import groupby

    if target == "unmapped":
        assert len(alignments) == 1
        alignment1, alignment2 = alignments[0]
        output.write(alignment1)
        output.write(alignment2)
        return
    if not alignments:
        raise ValueError("No alignments to write for %s" % target)
    if target == "genome":
        alignments = _select_shortest_alignments(alignments)
    else:
        _check_equal_lengths(alignments)
    on_chromosomes = [(alignment1, alignment2)
                      for (alignment1, alignment2) in alignments
                      if alignment1.reference_name in chromosomes]
    if not on_chromosomes:
        # Take the representative from the pairs that passed the selection
        # above, instead of relying on a variable left over from a loop.
        _write_pair(output, alignments[-1], alignments, target, unmapped=True)
        return
    # Compute each location once, and compare locations only while sorting.
    located = [(get_genome_location(alignment), alignment)
               for alignment in on_chromosomes]
    located.sort(key=lambda item: item[0])
    for location, group in groupby(located, key=lambda item: item[0]):
        block = [alignment for (_, alignment) in group]
        _write_pair(output, block[0], block, target)

# Reference sequence names and lengths for the header of the output BAM file.
chromosomes, sizes = read_chromosome_sizes(assembly)

directory = "/osc-fs_home/mdehoon/Data/CASPARs/MiSeq"

# Open the gzipped FASTQ files containing read 1 and read 2 of the library.
fastq_directory = os.path.join(directory, "Fastq")
path1 = os.path.join(fastq_directory, "%s_READ1.fq.gz" % library)
path2 = os.path.join(fastq_directory, "%s_READ2.fq.gz" % library)
stream1 = gzip.open(path1, "rt")
stream2 = gzip.open(path2, "rt")
records1 = SeqIO.parse(stream1, "fastq")
records2 = SeqIO.parse(stream2, "fastq")

# List the files in the directory containing the per-target BAM files.
subdirectory = os.path.join(directory, "BAM")
filenames = os.listdir(subdirectory)

# The merged BAM file is created in the current directory.
filename = "%s.bam" % library
print("Writing", filename)
output = pysam.AlignmentFile(filename, "wb", reference_names=chromosomes, reference_lengths=sizes)

# Open every file named "<library>.<target>.bam" and store it under its target.
alignments = {}
for filename in filenames:
    parts = filename.split(".")
    if len(parts) != 3:
        continue
    prefix, name, extension = parts
    if prefix != library or extension != 'bam':
        continue
    target = name
    path = os.path.join(subdirectory, filename)
    print("Reading", path)
    bamfile = pysam.AlignmentFile(path)
    alignments[target] = bamfile
    # The input header may have additional keys beyond those in
    # output.header, so only the keys of output.header are compared.
    for key in output.header.keys():
        assert output.header[key] == bamfile.header[key]

# Reads listed in the index file are provided as an additional target.
alignments['unmapped'] = generate_unmapped_alignments(library)

# For each target, read the first pair of alignments and remember it together
# with its read name; targets without any alignments are left out.
cached = {}
for target, source in alignments.items():
    try:
        mate1 = next(source)
    except StopIteration:
        print("No alignments for %s" % target)
        continue
    mate2 = next(source)
    name = mate1.query_name
    assert name == mate2.query_name
    cached[target] = (name, (mate1, mate2))

# Go through the read pairs in the order of the FASTQ files. For each read
# pair, find the target whose cached alignment carries the name of the read,
# collect all consecutive alignments with that name from the target, and
# write them to the output file.
for record1, record2 in zip(records1, records2):
    assert record1.id == record2.id
    # The read name must consist of seven colon-separated fields; the first
    # four must match the values expected for this sequencing run, and the
    # last three must be integers.
    instrument, run_number, flowcell_ID, lane, tile, x_pos, y_pos = record1.id.split(":")
    assert instrument == "M00528"
    assert run_number == "115"
    assert flowcell_ID  == "000000000-AD0NU"
    assert lane == "1"
    tile = int(tile)
    x_pos = int(x_pos)
    y_pos = int(y_pos)
    # Find the target whose next unprocessed alignment belongs to this read.
    for target in cached:
        query_name, alignment = cached[target]
        if record1.id == query_name:
            break
    else:
        raise Exception("Failed to find alignment for %s" % record1.id)
    current_alignments = [alignment]
    # Alignments are stored as consecutive read 1 / read 2 entries; keep
    # reading pairs until a pair with a different read name is found.
    for alignment1 in alignments[target]:
        alignment2 = next(alignments[target])
        query_name1 = alignment1.query_name
        query_name2 = alignment2.query_name
        assert query_name1 == query_name2
        alignment = (alignment1, alignment2)
        if query_name1 != query_name:
            # This pair belongs to a later read; it is cached below.
            break
        current_alignments.append(alignment)
    else:
        # No alignments are left for this target.
        alignments[target].close()
        del cached[target]
        alignment = None
    write_alignments(output, current_alignments, chromosomes, target)
    if alignment:
        # Remember the pair that was read ahead, under the name of its read.
        cached[target] = (query_name1, alignment)
output.close()
stream1.close()
stream2.close()
